import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

from torchvision import transforms, datasets, models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm

from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    confusion_matrix,
    classification_report,
    roc_curve,
    precision_recall_curve
)


# --- Configuration ---------------------------------------------------------
# NOTE(review): both paths below are placeholders — fill in before running.
TRAIN_NPZ = "path to pseudolabel.npz"  # .npz holding "combined_paths"/"combined_labels"
TEST_FOLDER = "Path to test images"  # ImageFolder-style directory of test images

BATCH_SIZE = 16  # mini-batch size for both train and test loaders
EPOCHS = 10  # maximum training epochs per model
LR = 3e-5  # AdamW learning rate

# Train on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Preprocessing shared by the train and test pipelines. All three backbones
# below load ImageNet-pretrained DEFAULT weights, so images are resized to
# 224x224 and normalized with the ImageNet channel statistics those weights
# expect. (The original empty mean/std lists would raise at runtime the
# first time Normalize was applied.)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],  # ImageNet per-channel mean (RGB)
        [0.229, 0.224, 0.225],  # ImageNet per-channel std (RGB)
    )
])


class NPZDataset(Dataset):
    """Image dataset whose file paths and (pseudo-)labels come from an .npz.

    The archive must contain two aligned arrays:
      - ``combined_paths``:  image file paths
      - ``combined_labels``: integer class labels

    Args:
        npz_file: path to the .npz archive.
        transform: callable applied to each loaded PIL image.
    """

    def __init__(self, npz_file, transform):
        data = np.load(npz_file, allow_pickle=True)

        self.paths = data["combined_paths"]
        self.labels = data["combined_labels"]

        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img_path = self.paths[idx]
        label = int(self.labels[idx])

        # Image.open is lazy and keeps the file handle open; the context
        # manager guarantees it is closed even if the transform raises.
        with Image.open(img_path) as img:
            # Pretrained backbones expect 3-channel input.
            if img.mode != "RGB":
                img = img.convert("RGB")
            img = self.transform(img)

        return img, label


def get_efficientnet(num_classes=2):
    """Return an ImageNet-pretrained EfficientNetV2-S with a fresh head.

    Args:
        num_classes: output size of the replacement classification layer
            (default 2, matching this script's binary task).
    """
    model = models.efficientnet_v2_s(
        weights=models.EfficientNet_V2_S_Weights.DEFAULT
    )

    # Swap only the final linear layer; the backbone keeps its weights.
    model.classifier[1] = nn.Linear(
        model.classifier[1].in_features, num_classes
    )

    return model


def get_densenet(num_classes=2):
    """Return an ImageNet-pretrained DenseNet-121 with a fresh head.

    Args:
        num_classes: output size of the replacement classification layer
            (default 2, matching this script's binary task).
    """
    model = models.densenet121(
        weights=models.DenseNet121_Weights.DEFAULT
    )

    # Swap only the final linear layer; the backbone keeps its weights.
    model.classifier = nn.Linear(
        model.classifier.in_features, num_classes
    )

    return model


def get_vit(num_classes=2):
    """Return an ImageNet-pretrained ViT-B/16 with a fresh head.

    Args:
        num_classes: output size of the replacement classification layer
            (default 2, matching this script's binary task).
    """
    model = models.vit_b_16(
        weights=models.ViT_B_16_Weights.DEFAULT
    )

    # Swap only the final linear layer; the backbone keeps its weights.
    model.heads.head = nn.Linear(
        model.heads.head.in_features, num_classes
    )

    return model


def train_model(model, train_loader, test_loader, name):
    """Fine-tune *model* on train_loader, then plot curves and evaluate.

    Training uses AdamW, cosine-annealed LR, and label-smoothed
    cross-entropy, with early stopping on the mean training loss.
    """
    model.to(device)

    opt = optim.AdamW(model.parameters(), lr=LR)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

    loss_history = []
    acc_history = []

    best_seen = float("inf")
    max_bad_epochs = 4
    bad_epochs = 0

    for ep in range(EPOCHS):
        model.train()

        loss_sum = 0
        n_correct = 0
        n_seen = 0

        for batch_imgs, batch_labels in tqdm(train_loader, desc=f"{name} Epoch {ep+1}"):
            batch_imgs = batch_imgs.to(device)
            batch_labels = batch_labels.to(device)

            opt.zero_grad()
            logits = model(batch_imgs)
            loss = loss_fn(logits, batch_labels)
            loss.backward()
            opt.step()

            loss_sum += loss.item()
            n_correct += (logits.argmax(dim=1) == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

        sched.step()

        mean_loss = loss_sum / len(train_loader)
        acc = n_correct / n_seen

        loss_history.append(mean_loss)
        acc_history.append(acc)

        print(f"{name} Epoch {ep+1} | Loss: {mean_loss:.4f} | Accuracy: {acc:.4f}")

        # Early stopping: demand a meaningful (>1e-4) drop in mean loss.
        if mean_loss < best_seen - 1e-4:
            best_seen = mean_loss
            bad_epochs = 0
        else:
            bad_epochs += 1

        if bad_epochs >= max_bad_epochs:
            print("Early stopping triggered")
            break

    plot_training_curves(loss_history, acc_history, name)

    evaluate_model(model, test_loader, name)

def evaluate_model(model, loader, name):
    """Score the binary classifier on *loader* and print/plot all metrics."""
    model.eval()

    all_preds = []
    all_probs = []
    all_targets = []

    with torch.no_grad():
        for batch, targets in loader:
            logits = model(batch.to(device))

            # Probability of the positive class (index 1) and the hard label.
            pos_prob = torch.softmax(logits, dim=1)[:, 1]
            hard = torch.argmax(logits, dim=1)

            all_preds.extend(hard.cpu().numpy())
            all_probs.extend(pos_prob.cpu().numpy())
            all_targets.extend(targets.numpy())

    acc = accuracy_score(all_targets, all_preds)
    prec = precision_score(all_targets, all_preds)
    rec = recall_score(all_targets, all_preds)
    f1 = f1_score(all_targets, all_preds)
    auc = roc_auc_score(all_targets, all_probs)

    cm = confusion_matrix(all_targets, all_preds)
    tn, fp, fn, tp = cm.ravel()

    # Sensitivity == recall of the positive class; specificity is the
    # true-negative rate.
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)

    print("\n", name)
    print("Accuracy:", round(acc, 4))
    print("Precision:", round(prec, 4))
    print("Recall:", round(rec, 4))
    print("F1:", round(f1, 4))
    print("AUC:", round(auc, 4))
    print("Sensitivity:", round(sensitivity, 4))
    print("Specificity:", round(specificity, 4))

    print("\nClassification Report")
    print(classification_report(all_targets, all_preds, digits=4))

    plot_confusion_matrix(cm, name)
    plot_roc(all_targets, all_probs, name)
    plot_pr(all_targets, all_probs, name)


def plot_training_curves(losses, accs, name):
    """Show side-by-side per-epoch training loss and accuracy for *name*."""
    xs = list(range(1, len(losses) + 1))

    plt.figure(figsize=(10, 4))

    # Left panel: loss per epoch.
    plt.subplot(1, 2, 1)
    plt.plot(xs, losses, marker='o')
    plt.title(name + " Training Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.ylim(0, max(losses) + 0.1)

    # Right panel: accuracy per epoch, expressed as a percentage.
    plt.subplot(1, 2, 2)
    plt.plot(xs, [a * 100 for a in accs], marker='o')
    plt.title(name + " Training Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy (%)")
    plt.ylim(0, 100)

    plt.tight_layout()
    plt.show()


def plot_confusion_matrix(cm, name):
    """Render the 2x2 confusion matrix *cm* as an annotated heatmap."""
    class_names = ["Class0", "Class1"]
    ticks = np.arange(len(class_names))

    plt.figure(figsize=(5, 5))
    plt.imshow(cm, cmap="Blues")
    plt.title(name + " Confusion Matrix")
    plt.colorbar()
    plt.xticks(ticks, class_names)
    plt.yticks(ticks, class_names)

    # Write each count at the centre of its cell.
    n_rows, n_cols = cm.shape
    for r in range(n_rows):
        for c in range(n_cols):
            plt.text(c, r, str(cm[r, c]),
                     ha="center",
                     va="center",
                     color="black",
                     fontsize=12)

    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.tight_layout()
    plt.show()


def plot_roc(labels, probs, name):
    """Plot the ROC curve for *probs* against *labels*, with a chance line."""
    fpr, tpr, _thresholds = roc_curve(labels, probs)

    plt.figure()
    plt.plot(fpr, tpr, label="ROC")
    # Diagonal = performance of a random classifier.
    plt.plot([0, 1], [0, 1], '--', label="Random")

    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.title(name + " ROC Curve")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend()
    plt.show()


def plot_pr(labels, probs, name):
    """Plot the precision-recall curve for *probs* against *labels*."""
    precision, recall, _thresholds = precision_recall_curve(labels, probs)

    plt.figure()
    plt.plot(recall, precision)

    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.title(name + " Precision Recall Curve")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.show()


def main():
    """Build the data loaders, then train and evaluate all three backbones."""
    train_loader = DataLoader(
        NPZDataset(TRAIN_NPZ, transform),
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
    )

    test_loader = DataLoader(
        datasets.ImageFolder(TEST_FOLDER, transform=transform),
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=0,
    )

    # Same loaders for every backbone so the comparison is apples-to-apples.
    for factory, label in (
        (get_efficientnet, "EfficientNetV2"),
        (get_densenet, "DenseNet121"),
        (get_vit, "ViT-B16"),
    ):
        train_model(factory(), train_loader, test_loader, label)


# Script entry point: train and evaluate all three models.
if __name__=="__main__":
    main()